CUDA: 4D FlashAttention support #14628
Conversation
Tests are passing on RTX 2060
Force-pushed from f23950a to ab82dc2
Force-pushed from 326e4e2 to 2f9b295
There was some issue with the WMMA kernel (which is now fixed), merge when convenient for you.
Merged commit c43f275 into ggml-org:gg/llama-high-throughput
* CUDA: 4D FlashAttention support
* CUDA: fix WMMA FA kernel
Something is wrong, I'm getting a ton of failures on 3090Ti (CUDA 12.9).
You are testing master. This was merged in another branch.
Ah, LOL, sorry. :) Why is master failing though?
If master is failing, can you do a git bisect to determine since when?
It's failing the mask->ne[2] != 1 tests. These are not relevant.
* kv-cache : prepare K/V buffers for separation ggml-ci
* batched-bench : fix oob write ggml-ci
* llama : add "virtual sequences" ggml-ci
* llama : use "stream" vs "virtual sequence" ggml-ci
* graph : fix stream splitting when KV cache is not used ggml-ci
* kv-cache : add multi-stream save/load support ggml-ci
* llama : add "--attn-streams" flag ggml-ci
* kv-cache : fix handling when find_slot fails ggml-ci
* kv-cache : restore find_slot impl ggml-ci
* kv-cache : add comments
* kv-cache : add bounds checks for sequence id ggml-ci
* cont : add n_seq_max to batch allocr ggml-ci
* kv-cache : perform stream copies lazily after llama_synchronize ggml-ci
* kv-cache : avoid throwing exceptions across the C boundary ggml-ci
* CUDA: 4D FlashAttention support (#14628)
* CUDA: 4D FlashAttention support
* CUDA: fix WMMA FA kernel
* llama : rename attn_streams -> kv_unified ggml-ci
* common : rename kv_split -> kv_unified ggml-ci
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This PR adds 4-dimensional CUDA FlashAttention support for #14363. The data layout for the fixup was changed, but there should be no change to performance. As discussed in #14505 (comment), the CUDA code requires mask->ne[2] == 1; otherwise it would require additional complexity to ensure that the GQA-specific optimizations in fattn-mma-f16.cuh produce correct results.
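For illustration only, here is a minimal C++ sketch of the kind of shape check described above: masks with ne[2] != 1 are rejected, while a 4D Q/K/V layout (ne[3] > 1) is allowed. The helper name fa_shapes_supported and the ne[3] consistency check are assumptions made for this sketch; the actual logic lives in the CUDA backend and fattn-mma-f16.cuh.

```cpp
// Hypothetical, standalone sketch -- NOT the actual CUDA-backend check.
// It only illustrates the constraint stated in the PR description:
// Q/K/V may be 4D (ne[3] > 1), but the mask must be broadcast over
// dimension 2, i.e. mask->ne[2] == 1.
#include "ggml.h" // for struct ggml_tensor (ne[] holds the dimension sizes)

bool fa_shapes_supported(            // hypothetical helper name
        const struct ggml_tensor * q,
        const struct ggml_tensor * k,
        const struct ggml_tensor * v,
        const struct ggml_tensor * mask) {
    // the mask may be absent; if present it must not vary along dimension 2
    if (mask != nullptr && mask->ne[2] != 1) {
        return false;
    }
    // assumption for this sketch: the 4th dimension must agree across Q/K/V
    if (q->ne[3] != k->ne[3] || q->ne[3] != v->ne[3]) {
        return false;
    }
    return true;
}
```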